from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import yaml

import numpy as np
from easydict import EasyDict as edict


# Root of the configuration tree.  update_dir() may later override
# OUTPUT_DIR, LOG_DIR and DATA_DIR.
config = edict(
    OUTPUT_DIR='',
    LOG_DIR='',
    DATA_DIR='',
    GPUS='0',
    WORKERS=4,
    PRINT_FREQ=20,
    SEED=0,
)

# cuDNN backend switches
config.CUDNN = edict(
    BENCHMARK=True,
    DETERMINISTIC=False,
    ENABLED=True,
)

# Default EXTRA parameters for the 'pose_resnet' model.
POSE_RESNET = edict(
    NUM_LAYERS=50,
    DECONV_WITH_BIAS=False,
    NUM_DECONV_LAYERS=3,
    NUM_DECONV_FILTERS=[256, 256, 256],
    NUM_DECONV_KERNELS=[4, 4, 4],
    FINAL_CONV_KERNEL=1,
    TARGET_TYPE='gaussian',
    HEATMAP_SIZE=[64, 64],  # width * height, ex: 24 * 32
    SIGMA=2,
    RANDOM_DOWNSAMPLE=False,
)

# Default EXTRA parameters for the 'pose_hrnet' (high-resolution net) model.
# Stage k has k parallel branches; the default channel widths double from
# one branch to the next.
POSE_HIGH_RESOLUTION_NET = edict(
    PRETRAINED_LAYERS=['*'],
    STEM_INPLANES=64,
    FINAL_CONV_KERNEL=1,
    TARGET_TYPE='gaussian',
    HEATMAP_SIZE=[64, 64],  # width * height, ex: 192 * 256
    SIGMA=2,
    STAGE2=edict(
        NUM_MODULES=1,
        NUM_BRANCHES=2,
        NUM_BLOCKS=[4, 4],
        NUM_CHANNELS=[32, 64],
        BLOCK='BASIC',
        FUSE_METHOD='SUM',
    ),
    STAGE3=edict(
        NUM_MODULES=1,
        NUM_BRANCHES=3,
        NUM_BLOCKS=[4, 4, 4],
        NUM_CHANNELS=[32, 64, 128],
        BLOCK='BASIC',
        FUSE_METHOD='SUM',
    ),
    STAGE4=edict(
        NUM_MODULES=1,
        NUM_BRANCHES=4,
        NUM_BLOCKS=[4, 4, 4, 4],
        NUM_CHANNELS=[32, 64, 128, 256],
        BLOCK='BASIC',
        FUSE_METHOD='SUM',
    ),
)

# Default EXTRA parameters for the 'temp_tcn' model.
# NOTE(review): get_model_name() reads EXTRA.STAGE for 'temp_tcn', which is
# not defined here, so it presumably must come from the experiment YAML.
POSE_TCN = edict(
    NUM_LAYERS=50,
    DECONV_WITH_BIAS=False,
    NUM_DECONV_LAYERS=3,
    NUM_DECONV_FILTERS=[256, 256, 256],
    NUM_DECONV_KERNELS=[4, 4, 4],
    FINAL_CONV_KERNEL=1,
    TARGET_TYPE='coords',
    HEATMAP_SIZE=[64, 64],  # width * height, ex: 24 * 32
    SIGMA=2,
)

# Default EXTRA parameters for 'pose_unet', 'pose_encoderdecoder' and
# 'pose_asymmetrical_encodeco'; all three only carry heatmap settings.
POSE_UNET = edict(
    HEATMAP_SIZE=[64, 64],
    TARGET_TYPE='gaussian',
    SIGMA=2,
)

POSE_ENCODERDECODER = edict(
    HEATMAP_SIZE=[64, 64],
    TARGET_TYPE='gaussian',
    SIGMA=2,
)

POSE_ASYM_ENCODERDECODER = edict(
    HEATMAP_SIZE=[64, 64],
    TARGET_TYPE='gaussian',
    SIGMA=2,
)

# Default EXTRA parameters for the 'pose_resnet_aps' model: the pose_resnet
# settings plus the FILTER_SIZE / PADDING_MODE / APS_CRITERION /
# DOWNSAMPLE_TYPE / UPSAMPLE_TYPE options.
POSE_RESNET_APS = edict(
    NUM_LAYERS=50,
    DECONV_WITH_BIAS=False,
    NUM_DECONV_LAYERS=3,
    NUM_DECONV_FILTERS=[256, 256, 256],
    NUM_DECONV_KERNELS=[4, 4, 4],
    FINAL_CONV_KERNEL=1,
    TARGET_TYPE='gaussian',
    HEATMAP_SIZE=[64, 64],  # width * height, ex: 24 * 32
    SIGMA=2,
    FILTER_SIZE=5,
    PADDING_MODE='circular',
    APS_CRITERION='l2',
    DOWNSAMPLE_TYPE='aps',
    UPSAMPLE_TYPE='aps',
)

# EXTRA parameters for an APS variant of the high-resolution net; the values
# mirror POSE_HIGH_RESOLUTION_NET.
# NOTE(review): this block is not registered in MODEL_EXTRAS, and
# get_model_name() has no branch for it, so nothing in this file selects it.
POSE_HIGH_RESOLUTION_NET_APS = edict(
    PRETRAINED_LAYERS=['*'],
    STEM_INPLANES=64,
    FINAL_CONV_KERNEL=1,
    TARGET_TYPE='gaussian',
    HEATMAP_SIZE=[64, 64],  # width * height, ex: 192 * 256
    SIGMA=2,
    STAGE2=edict(
        NUM_MODULES=1,
        NUM_BRANCHES=2,
        NUM_BLOCKS=[4, 4],
        NUM_CHANNELS=[32, 64],
        BLOCK='BASIC',
        FUSE_METHOD='SUM',
    ),
    STAGE3=edict(
        NUM_MODULES=1,
        NUM_BRANCHES=3,
        NUM_BLOCKS=[4, 4, 4],
        NUM_CHANNELS=[32, 64, 128],
        BLOCK='BASIC',
        FUSE_METHOD='SUM',
    ),
    STAGE4=edict(
        NUM_MODULES=1,
        NUM_BRANCHES=4,
        NUM_BLOCKS=[4, 4, 4, 4],
        NUM_CHANNELS=[32, 64, 128, 256],
        BLOCK='BASIC',
        FUSE_METHOD='SUM',
    ),
)

# Lookup table: config.MODEL.NAME -> default EXTRA parameter block.
MODEL_EXTRAS = dict(
    pose_resnet=POSE_RESNET,
    pose_hrnet=POSE_HIGH_RESOLUTION_NET,
    temp_tcn=POSE_TCN,
    pose_unet=POSE_UNET,
    pose_encoderdecoder=POSE_ENCODERDECODER,
    pose_asymmetrical_encodeco=POSE_ASYM_ENCODERDECODER,
    pose_resnet_aps=POSE_RESNET_APS,
)

# Union of keypoints (empty by default).
config.UNION_KEYPOINTS = []

# Network settings common to every model.  EXTRA starts out as the default
# block registered for NAME in MODEL_EXTRAS.
config.MODEL = edict(
    NAME='pose_resnet',
    INIT_WEIGHTS=True,
    PRETRAINED='',
    NUM_JOINTS=16,
    IMAGE_SIZE=[256, 256],  # width * height, ex: 192 * 256
)
config.MODEL.EXTRA = MODEL_EXTRAS[config.MODEL.NAME]
config.MODEL.STYLE = 'pytorch'

config.LOSS = edict(USE_TARGET_WEIGHT=True)

# Dataset settings: TRAIN_DATASET is a list of dataset entries, while
# TEST_DATASET is a single entry.
trainset_config = edict(
    DATASET='mpii',
    SUBSET='train',
    ROOT='',
    SCALE_FACTOR=0,
    ROT_FACTOR=0,
    FLIP=False,
)

testset_config = edict(
    DATASET='mpii',
    SUBSET='valid',
    ROOT='',
)

config.DATASET = edict(
    TRAIN_DATASET=[trainset_config],
    TEST_DATASET=testset_config,
    COLOR_JITTER=False,
    DATA_FORMAT='jpg',
    WINDOW_SIZE=3,
)

# Training schedule, optimizer and checkpoint settings.
config.TRAIN = edict(
    # learning rate
    LR_FACTOR=0.1,
    LR_STEP=[90, 110],
    LR=0.001,
    # optimizer
    OPTIMIZER='adam',
    MOMENTUM=0.9,
    WD=0.0001,
    NESTEROV=False,
    GAMMA1=0.99,
    GAMMA2=0.0,
    # epoch range
    BEGIN_EPOCH=0,
    END_EPOCH=140,
    # resuming / checkpoints
    RESUME=False,
    RESUME_PATH='',
    ON_SERVER_CLUSTER=False,
    CHECKPOINT='',
    # batching
    BATCH_SIZE=32,
    SHUFFLE=True,
)

# Evaluation settings.
config.TEST = edict(
    BATCH_SIZE=32,  # size of images for each device
    FLIP_TEST=False,
    POST_PROCESS=False,
    SHIFT_HEATMAP=False,
    USE_GT_BBOX=False,
    # nms
    SOFT_NMS=False,
    OKS_THRE=0.5,
    IN_VIS_THRE=0.0,
    COCO_BBOX_FILE='',
    BBOX_FILE='',
    BBOX_THRE=1.0,
    # disabled option: MATCH_IOU_THRE=0.3
    DETECTOR='fpn_dcn',
    DETECTOR_DIR='',
    MODEL_FILE='',
    IMAGE_THRE=0.0,
    NMS_THRE=1.0,
    # disabled option: FUSE_OUTPUT=True
)

# Debug switches (all off by default).
config.DEBUG = edict(
    DEBUG=False,
    SAVE_BATCH_IMAGES_GT=False,
    SAVE_BATCH_IMAGES_PRED=False,
    SAVE_HEATMAPS_GT=False,
    SAVE_HEATMAPS_PRED=False,
)


def _as_size_pair(size):
    """Return *size* as a two-entry numpy array.

    An ``int`` ``s`` is expanded to ``[s, s]``; any other value (e.g. a
    ``[width, height]`` list) is passed to ``np.array`` unchanged.
    """
    if isinstance(size, int):
        return np.array([size, size])
    return np.array(size)


def _update_dict(k, v):
    """Merge the YAML sub-section *v* into the existing section ``config[k]``.

    Values are copied key by key; a nested value such as MODEL.EXTRA replaces
    the existing one wholesale.  Before merging:

    * DATASET.MEAN / DATASET.STD lists become numpy arrays, with string
      entries evaluated through ``eval``;
    * MODEL.EXTRA.HEATMAP_SIZE and MODEL.IMAGE_SIZE become two-entry numpy
      arrays (a bare int is duplicated).

    If a MODEL section sets NAME without providing EXTRA, EXTRA is reset to
    the default registered for that name in MODEL_EXTRAS.

    Args:
        k: top-level section name, e.g. ``'MODEL'``.
        v: mapping parsed from YAML; modified in place.

    Raises:
        ValueError: if *v* holds a key that ``config[k]`` does not define.
    """
    if k == 'DATASET':
        for stat in ('MEAN', 'STD'):
            if stat in v and v[stat]:
                # NOTE(security): eval() executes arbitrary expressions from
                # the YAML file; only load trusted config files.
                v[stat] = np.array([eval(x) if isinstance(x, str) else x
                                    for x in v[stat]])
    if k == 'MODEL':
        if 'EXTRA' in v and 'HEATMAP_SIZE' in v['EXTRA']:
            v['EXTRA']['HEATMAP_SIZE'] = _as_size_pair(
                v['EXTRA']['HEATMAP_SIZE'])
        if 'IMAGE_SIZE' in v:
            v['IMAGE_SIZE'] = _as_size_pair(v['IMAGE_SIZE'])

    for vk, vv in v.items():
        if vk not in config[k]:
            raise ValueError("{}.{} not exist in config.py".format(k, vk))
        config[k][vk] = vv

    if k == 'MODEL' and 'EXTRA' not in v:
        name = v.get('NAME')
        # Without this, selecting another model keeps the previously
        # configured EXTRA block (initially the pose_resnet defaults).
        if isinstance(name, str) and name in MODEL_EXTRAS:
            config[k]['EXTRA'] = MODEL_EXTRAS[name]


def update_config(config_file):
    """Overlay the experiment YAML file *config_file* onto the global config.

    Top-level mapping values are merged key by key via ``_update_dict``; any
    other top-level value replaces the existing entry outright (``SCALES``
    is special-cased and written into its first slot as a tuple).

    Raises:
        ValueError: if the file holds a top-level key unknown to ``config``.
    """
    with open(config_file) as stream:
        parsed = yaml.load(stream, Loader=yaml.FullLoader)
    exp_config = edict(parsed)

    for key, value in exp_config.items():
        if key not in config:
            raise ValueError("{} not exist in config.py".format(key))
        if isinstance(value, dict):
            _update_dict(key, value)
        elif key == 'SCALES':
            config[key][0] = tuple(value)
        else:
            config[key] = value


def _to_builtin(obj):
    """Recursively convert *obj* into plain Python containers and scalars.

    dict subclasses (such as EasyDict) become ``dict``, lists and tuples
    become ``list``, numpy arrays become nested lists and numpy scalars
    become the matching Python scalar.  Anything else is returned as is.
    """
    if isinstance(obj, dict):
        return {key: _to_builtin(val) for key, val in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [_to_builtin(item) for item in obj]
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        return obj.item()
    return obj


def gen_config(config_file):
    """Write the current global ``config`` to *config_file* as YAML.

    The whole tree is converted to builtin types first.  Nested sections
    (e.g. MODEL.EXTRA and the DATASET.*_DATASET entries) are EasyDict
    instances, and sizes may be numpy arrays after update_config(); the
    default yaml Dumper would otherwise emit Python-specific object tags for
    them instead of plain mappings and sequences.
    """
    with open(config_file, 'w') as f:
        yaml.dump(_to_builtin(config), f, default_flow_style=False)


def update_dir(model_dir, log_dir, data_dir):
    """Override the output/log/data directories and re-root data paths.

    Each non-empty argument replaces config.OUTPUT_DIR, config.LOG_DIR or
    config.DATA_DIR respectively.  Afterwards the ROOT of every train and
    test dataset entry, TEST.COCO_BBOX_FILE and MODEL.PRETRAINED are joined
    onto config.DATA_DIR.
    """
    if model_dir:
        config.OUTPUT_DIR = model_dir

    if log_dir:
        config.LOG_DIR = log_dir

    if data_dir:
        config.DATA_DIR = data_dir

    # config.DATASET has no ROOT of its own (and _update_dict refuses to add
    # one); dataset roots live on the TRAIN_DATASET / TEST_DATASET entries,
    # each of which may be a single entry or a list of entries.
    dataset_entries = []
    for field in ('TRAIN_DATASET', 'TEST_DATASET'):
        entries = config.DATASET[field]
        if isinstance(entries, (list, tuple)):
            dataset_entries.extend(entries)
        else:
            dataset_entries.append(entries)
    for entry in dataset_entries:
        if isinstance(entry, dict) and 'ROOT' in entry:
            entry['ROOT'] = os.path.join(config.DATA_DIR, entry['ROOT'])

    config.TEST.COCO_BBOX_FILE = os.path.join(
            config.DATA_DIR, config.TEST.COCO_BBOX_FILE)

    config.MODEL.PRETRAINED = os.path.join(
            config.DATA_DIR, config.MODEL.PRETRAINED)


def _join_suffix(values):
    """Return ``'d<v0>d<v1>...'`` for the given sequence of values."""
    return ''.join('d{}'.format(value) for value in values)


def get_model_name(cfg):
    """Build the short and resolution-qualified names of the configured model.

    Args:
        cfg: config object exposing MODEL.NAME, MODEL.EXTRA and
            MODEL.IMAGE_SIZE (indexed as [width, height]).

    Returns:
        ``(name, full_name)`` where *name* encodes the architecture variant
        and *full_name* is ``'<height>x<width>_<name>_<suffix>'``.

    Raises:
        ValueError: if MODEL.NAME is not one of the supported models.
    """
    model = cfg.MODEL.NAME
    extra = cfg.MODEL.EXTRA
    # These models use a fixed '<model>_4stage' name; only the suffix differs.
    fixed_suffixes = {
        'pose_unet': '2up',
        'pose_encoderdecoder': '4up',
        'pose_asymmetrical_encodeco': '2up',
    }

    if model in ('pose_resnet', 'pose_resnet_aps'):
        name = '{}_{}'.format(model, extra.NUM_LAYERS)
        suffix = _join_suffix(extra.NUM_DECONV_FILTERS)
    elif model == 'pose_hrnet':
        # Named after STAGE2.NUM_CHANNELS[0], i.e. the first branch width.
        name = '{}_{}'.format(model, extra.STAGE2.NUM_CHANNELS[0])
        suffix = _join_suffix(extra.STAGE4.NUM_BLOCKS)
    elif model == 'temp_tcn':
        name = '{}_{}'.format(model, len(extra.STAGE))
        suffix = _join_suffix(extra.STAGE)
    elif model in fixed_suffixes:
        name = '{}_4stage'.format(model)
        suffix = fixed_suffixes[model]
    else:
        raise ValueError('Unknown model: {}'.format(model))

    full_name = '{height}x{width}_{name}_{suffix}'.format(
        height=cfg.MODEL.IMAGE_SIZE[1],
        width=cfg.MODEL.IMAGE_SIZE[0],
        name=name,
        suffix=suffix)
    return name, full_name


if __name__ == '__main__':
    # Dump the default configuration to the YAML path given as the first
    # command-line argument.
    from sys import argv
    gen_config(argv[1])
